{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to XGBoost Spark3.0 with GPU\n",
    "\n",
    "Agaricus is an example of XGBoost classifier for multiple classification. This notebook will show you how to load data, train the xgboost model. Comparing to original XGBoost Spark code, there're only one API difference.\n",
    "\n",
    "## Load libraries\n",
    "First load some common libraries will be used by both GPU version and CPU version XGBoost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}\n",
    "import org.apache.spark.sql.SparkSession\n",
    "import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\n",
    "import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides CPU version requires some extra libraries, such as:\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.ml.feature.VectorAssembler\n",
    "import org.apache.spark.sql.functions._\n",
    "```"
   ]
  },
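  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These extra imports are typically needed because the CPU version expects the features packed into a single vector column, while the GPU version takes the feature column names directly. A minimal sketch of that extra step, assuming a hypothetical helper `assembleFeatures` and an output column named `features` (neither is part of this notebook):\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.ml.feature.VectorAssembler\n",
    "import org.apache.spark.sql.DataFrame\n",
    "\n",
    "// Hypothetical helper: pack the listed feature columns into one vector column\n",
    "// named \"features\" (an assumed name) and keep only that column plus the label.\n",
    "def assembleFeatures(df: DataFrame, featureCols: Array[String], labelName: String): DataFrame =\n",
    "  new VectorAssembler()\n",
    "    .setInputCols(featureCols)\n",
    "    .setOutputCol(\"features\")\n",
    "    .transform(df)\n",
    "    .select(\"features\", labelName)\n",
    "```\n",
    "\n",
    "In a CPU run, a helper like this would be applied to each dataset before training, and the single column name `features` would be passed to `setFeaturesCol`."
   ]
  },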
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the dataset paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "lastException = null\n",
       "dataRoot = /data\n",
       "trainPath = /data/agaricus/csv/train/\n",
       "evalPath = /data/agaricus/csv/test/\n",
       "transPath = /data/agaricus/csv/test/\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "/data/agaricus/csv/test/"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// You need to update them to your real paths!\n",
    "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n",
    "val trainPath = dataRoot + \"/agaricus/csv/train/\"\n",
    "val evalPath  = dataRoot + \"/agaricus/csv/test/\"\n",
    "val transPath = dataRoot + \"/agaricus/csv/test/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the schema of the dataset\n",
    "\n",
    "For agaricus example, the data has 126 dimensions, being named as \"feature_0\", \"feature_1\" ... \"feature_125\". The schema will be used to load data in the future."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "labelName = label\n",
       "dataSchema = StructType(StructField(label,DoubleType,true), StructField(feature_0,DoubleType,true), StructField(feature_1,DoubleType,true), StructField(feature_2,DoubleType,true), StructField(feature_3,DoubleType,true), StructField(feature_4,DoubleType,true), StructField(feature_5,DoubleType,true), StructField(feature_6,DoubleType,true), StructField(feature_7,DoubleType,true), StructField(feature_8,DoubleType,true), StructField(feature_9,DoubleType,true), StructField(feature_10,DoubleType,true), StructField(feature_11,DoubleType,true), StructField(feature_12,DoubleType,true), StructField(feature_13,DoubleType,true), StructFiel...\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "columnNames: (length: Int)List[String]\n",
       "schema: (length: Int)org.apache.spark.sql.types.StructType\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "StructType(StructField(label,DoubleType,true), StructField(feature_0,DoubleType,true), StructField(feature_1,DoubleType,true), StructField(feature_2,DoubleType,true), StructField(feature_3,DoubleType,true), StructField(feature_4,DoubleType,true), StructField(feature_5,DoubleType,true), StructField(feature_6,DoubleType,true), StructField(feature_7,DoubleType,true), StructField(feature_8,DoubleType,true), StructField(feature_9,DoubleType,true), StructField(feature_10,DoubleType,true), StructField(feature_11,DoubleType,true), StructField(feature_12,DoubleType,true), StructField(feature_13,DoubleType,true), StructFiel..."
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val labelName = \"label\"\n",
    "def columnNames(length: Int): List[String] =\n",
    "  0.until(length).map(i => s\"feature_$i\").toList.+:(labelName)\n",
    "\n",
    "def schema(length: Int): StructType =\n",
    "  StructType(columnNames(length).map(n => StructField(n, DoubleType)))\n",
    "\n",
    "val dataSchema = schema(126)\n",
    "\n",
    "// Build the column name list for features.\n",
    "val featureCols = dataSchema.filter(_.name != labelName).map(_.name).toArray"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a new spark session and load data\n",
    "\n",
    "A new spark session should be created to continue all the following spark operations.\n",
    "\n",
    "NOTE: in this notebook, the dependency jars have been loaded when installing toree kernel. Alternatively the jars can be loaded into notebook by [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, there's one restriction for `%AddJar`: the jar uploaded can only be available when `AddJar` is called just after a new spark session is created. Do it as below:\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.sql.SparkSession\n",
    "val spark = SparkSession.builder().appName(\"agaricus-GPU\").getOrCreate\n",
    "%AddJar file:/data/libs/rapids-4-spark-XXX.jar\n",
    "%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\n",
    "%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\n",
    "// ...\n",
    "```\n",
    "\n",
    "##### Please note the new jar \"rapids-4-spark-XXX.jar\" is only needed for GPU version, you can not add it to dependence list for CPU version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "sparkSession = org.apache.spark.sql.SparkSession@3886ba44\n",
       "dataReader = org.apache.spark.sql.DataFrameReader@5c8be07f\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "org.apache.spark.sql.DataFrameReader@5c8be07f"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// Build the spark session and data reader as usual\n",
    "val sparkSession = SparkSession.builder.appName(\"agaricus-gpu\").getOrCreate\n",
    "val dataReader = sparkSession.read.option(\"header\", true).schema(dataSchema)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "trainSet = [label: double, feature_0: double ... 125 more fields]\n",
       "evalSet = [label: double, feature_0: double ... 125 more fields]\n",
       "transSet = [label: double, feature_0: double ... 125 more fields]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[label: double, feature_0: double ... 125 more fields]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// load all the dataset\n",
    "val trainSet = dataReader.csv(trainPath)\n",
    "val evalSet  = dataReader.csv(evalPath)\n",
    "val transSet = dataReader.csv(transPath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set XGBoost parameters and build a XGBoostClassifier\n",
    "\n",
    "For CPU version, `num_workers` is recommended being equal to the number of CPU cores, while for GPU version, it should be set to the number of GPUs in Spark cluster.\n",
    "\n",
    "Besides the `tree_method` for CPU version is also different from that for GPU version. Now only \"gpu_hist\" is supported for training on GPU.\n",
    "\n",
    "```scala\n",
    "// difference in parameters\n",
    "  \"num_workers\" -> 12,\n",
    "  \"tree_method\" -> \"hist\",\n",
    "```"
   ]
  },
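  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two differences described above can be kept in one place. A minimal sketch, assuming a hypothetical helper `xgbParams` (not part of XGBoost4J-Spark) whose `useGpu` flag selects the tree method:\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper: build the parameter map for either backend.\n",
    "// numWorkers should be the number of GPUs for GPU runs, or the number\n",
    "// of CPU cores for CPU runs, as recommended above.\n",
    "def xgbParams(useGpu: Boolean, numWorkers: Int, numRound: Int = 100): Map[String, Any] =\n",
    "  Map(\n",
    "    \"num_workers\" -> numWorkers,\n",
    "    \"tree_method\" -> (if (useGpu) \"gpu_hist\" else \"hist\"),\n",
    "    \"num_round\" -> numRound\n",
    "  )\n",
    "\n",
    "val gpuParams = xgbParams(useGpu = true, numWorkers = 1)\n",
    "val cpuParams = xgbParams(useGpu = false, numWorkers = 12)\n",
    "```\n",
    "\n",
    "Either map could then be passed to `new XGBoostClassifier(...)` in place of a hand-written one."
   ]
  },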
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "paramMap = Map(num_workers -> 1, tree_method -> gpu_hist, num_round -> 100)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Map(num_workers -> 1, tree_method -> gpu_hist, num_round -> 100)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// build XGBoost classifier\n",
    "val paramMap = Map(\n",
    "  \"num_workers\" -> 1,\n",
    "  \"tree_method\" -> \"gpu_hist\",\n",
    "  \"num_round\" -> 100\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "xgbClassifier = xgbc_57e2d7fc657a\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "xgbc_57e2d7fc657a"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val xgbClassifier  = new XGBoostClassifier(paramMap)\n",
    "  .setLabelCol(labelName)\n",
    "  .setFeaturesCol(featureCols)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark and train\n",
    "The object `benchmark` is used to compute the elapsed time of some operations.\n",
    "\n",
    "Training with evaluation dataset is also supported, the same as CPU version's behavior:\n",
    "\n",
    "* Call API `setEvalDataset` after initializing an XGBoostClassifier\n",
    "\n",
    "```scala\n",
    "xgbClassifier.setEvalDataset(evalSet)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "xgbc_57e2d7fc657a"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgbClassifier.setEvalDataset(evalSet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defined object Benchmark\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "object Benchmark {\n",
    "  def time[R](phase: String)(block: => R): (R, Float) = {\n",
    "    val t0 = System.currentTimeMillis\n",
    "    val result = block // call-by-name\n",
    "    val t1 = System.currentTimeMillis\n",
    "    println(\"Elapsed time [\" + phase + \"]: \" + ((t1 - t0).toFloat / 1000) + \"s\")\n",
    "    (result, (t1 - t0).toFloat / 1000)\n",
    "  }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "------ Training ------\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=34739, DMLC_NUM_WORKER=1}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "xgbClassificationModel = xgbc_57e2d7fc657a\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Elapsed time [train]: 11.177s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "xgbc_57e2d7fc657a"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// start training\n",
    "println(\"\\n------ Training ------\")\n",
    "val (xgbClassificationModel, _) = Benchmark.time(\"train\") {\n",
    "  xgbClassifier.fit(trainSet)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformation and evaluation\n",
    "Here uses `transSet` to evaluate our model and prints some useful columns to show our prediction result. After that `MulticlassClassificationEvaluator` is used to calculate an overall accuracy of our predictions."
   ]
  },
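  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of computing plain accuracy by setting the metric name explicitly, assuming a hypothetical helper `accuracyOf` and the default `prediction` output column:\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\n",
    "import org.apache.spark.sql.DataFrame\n",
    "\n",
    "// Hypothetical helper: fraction of rows whose prediction equals the label.\n",
    "def accuracyOf(predictions: DataFrame, labelCol: String): Double =\n",
    "  new MulticlassClassificationEvaluator()\n",
    "    .setLabelCol(labelCol)\n",
    "    .setPredictionCol(\"prediction\")\n",
    "    .setMetricName(\"accuracy\")\n",
    "    .evaluate(predictions)\n",
    "```\n",
    "\n",
    "Applied to the transformed dataset, `accuracyOf(results, labelName)` would give a value that may differ from the F1-based number printed by the next cell."
   ]
  },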
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "------ Transforming ------\n",
      "Elapsed time [transform]: 2.51s\n",
      "+-----+--------------------+--------------------+----------+\n",
      "|label|       rawPrediction|         probability|prediction|\n",
      "+-----+--------------------+--------------------+----------+\n",
      "|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-4.4405460357666...|[0.99995559453964...|       0.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "+-----+--------------------+--------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n",
      "\n",
      "------Accuracy of Evaluation------\n",
      "accuracy == 0.9069677632722861\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "results = [label: double, feature_0: double ... 128 more fields]\n",
       "evaluator = MulticlassClassificationEvaluator: uid=mcEval_8f89b3a17d4b, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\n",
       "accuracy = 0.9069677632722861\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "0.9069677632722861"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// start transform\n",
    "println(\"\\n------ Transforming ------\")\n",
    "val (results, _) = Benchmark.time(\"transform\") {\n",
    "  val ret = xgbClassificationModel.transform(transSet).cache()\n",
    "  ret.foreachPartition((_: Iterator[_]) => ())\n",
    "  ret\n",
    "}\n",
    "results.select(labelName, \"rawPrediction\", \"probability\", \"prediction\").show(10)\n",
    "\n",
    "println(\"\\n------Accuracy of Evaluation------\")\n",
    "val evaluator = new MulticlassClassificationEvaluator()\n",
    "evaluator.setLabelCol(labelName)\n",
    "val accuracy = evaluator.evaluate(results)\n",
    "\n",
    "println(s\"accuracy == $accuracy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the model to disk and load model\n",
    "Save the model to disk and then load it to memory. After that use the loaded model to do a new prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Elapsed time [transform2]: 0.069s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "modelFromDisk = xgbc_57e2d7fc657a\n",
       "results2 = [label: double, feature_0: double ... 128 more fields]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+--------------------+----------+\n",
      "|label|       rawPrediction|         probability|prediction|\n",
      "+-----+--------------------+--------------------+----------+\n",
      "|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-4.4405460357666...|[0.99995559453964...|       0.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  0.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "|  1.0|[-0.9999903440475...|[9.65595245361328...|       1.0|\n",
      "+-----+--------------------+--------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[label: double, feature_0: double ... 128 more fields]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgbClassificationModel.write.overwrite.save(dataRoot + \"/model/agaricus\")\n",
    "\n",
    "val modelFromDisk = XGBoostClassificationModel.load(dataRoot + \"/model/agaricus\")\n",
    "val (results2, _) = Benchmark.time(\"transform2\") {\n",
    "  modelFromDisk.transform(transSet)\n",
    "}\n",
    "results2.select(labelName, \"rawPrediction\", \"probability\", \"prediction\").show(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "sparkSession.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "XGBoost4j-Spark - Scala",
   "language": "scala",
   "name": "XGBoost4j-Spark_scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "mimetype": "text/x-scala",
   "name": "scala",
   "pygments_lexer": "scala",
   "version": "2.12.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
